#@Install
# !pip install easyfsl  # notebook shell magic, not valid Python — run in the environment before executing this file
#Import
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm
from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average
# Preprocess and split the Omniglot dataset.
# background=True selects the "background" alphabets (training split);
# background=False selects the evaluation alphabets (disjoint classes).
image_size = 28
train_set = Omniglot(
    root="./data",
    background=True,  # training split
    transform=transforms.Compose(
        [
            # The ResNet backbone expects 3 channels; Omniglot images are grayscale.
            transforms.Grayscale(num_output_channels=3),
            # Light augmentation, applied to the training split only.
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)
test_set = Omniglot(
    root="./data",
    background=False,  # evaluation split
    transform=transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=3),
            # Deterministic resize to ~1.15x then center-crop: no augmentation at test time.
            transforms.Resize([int(image_size * 1.15), int(image_size * 1.15)]),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ]
    ),
    download=True,
)
# Notebook output: Accuracy: 87.44%
#Prototypical Networks
class PrototypicalNetworks(nn.Module):
    """Prototypical Networks few-shot classifier.

    Each class prototype is the mean embedding of its support images; query
    images are scored by negative euclidean distance to each prototype.
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        # Feature extractor mapping an image batch to a (batch, features) tensor.
        self.backbone = backbone

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict query labels using labeled support images.

        Args:
            support_images: support batch, shape (n_support, C, H, W).
            support_labels: episode-local integer labels in [0, n_way), shape (n_support,).
            query_images: query batch, shape (n_query, C, H, W).

        Returns:
            Classification scores of shape (n_query, n_way); higher is better.
        """
        # Call the module (not .forward()) so registered hooks are honored.
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Infer the number of different classes from the support labels.
        n_way = len(torch.unique(support_labels))

        # Prototype i is the mean of the support embeddings whose label == i.
        # Boolean-mask indexing replaces the original torch.nonzero gymnastics.
        z_proto = torch.stack(
            [z_support[support_labels == label].mean(0) for label in range(n_way)]
        )

        # Scores are simply the negated euclidean distances to the prototypes.
        dists = torch.cdist(z_query, z_proto)
        return -dists
# Backbone: ImageNet-pretrained ResNet-18 with the classification head removed,
# so a forward pass yields a flat feature vector per image.
# NOTE(review): `pretrained=True` is deprecated in recent torchvision releases
# (use `weights=ResNet18_Weights.IMAGENET1K_V1`) — confirm the installed version.
convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()  # drop the fc classifier; features pass through
print(convolutional_network)
model = PrototypicalNetworks(convolutional_network).cuda()  # requires a CUDA device
#@title Hyperparameters
N_WAY = 3  # Number of classes per task
N_SHOT = 5  # Number of images per class (support set)
N_QUERY = 10  # Number of images per class (query set)
N_EVALUATION_TASKS = 1000  # episodes sampled for evaluation
# TaskSampler requires the dataset to expose a "get_labels" method, which
# Omniglot lacks, so attach one here. _flat_character_images is a list of
# (filename, character_class) pairs, hence instance[1] is the label.
test_set.get_labels = lambda: [instance[1] for instance in test_set._flat_character_images]
test_sampler = TaskSampler(
    test_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,  # each batch is one few-shot episode
    num_workers=12,
    pin_memory=True,
    # Regroup each batch into (support images/labels, query images/labels, class ids).
    collate_fn=test_sampler.episodic_collate_fn,
)
# Pull one example episode from the test loader and visualize it.
(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,  # dataset-level class ids for the episode's N_WAY classes
) = next(iter(test_loader))
plot_images(example_support_images, "support images", images_per_row=N_SHOT)
plot_images(example_query_images, "query images", images_per_row=N_QUERY)
# Score the example episode (backbone not yet fine-tuned on Omniglot) and
# print ground-truth vs predicted character names.
model.eval()
example_scores = model(
    example_support_images.cuda(),
    example_support_labels.cuda(),
    example_query_images.cuda(),
).detach()
# argmax over the N_WAY scores gives the predicted episode-local label.
_, example_predicted_labels = torch.max(example_scores.data, 1)
print("Ground Truth / Predicted")
for i in range(len(example_query_labels)):
    # Map episode-local labels back to character names via example_class_ids.
    print(
        f"{test_set._characters[example_class_ids[example_query_labels[i]]]} / {test_set._characters[example_class_ids[example_predicted_labels[i]]]}"
    )
# Notebook output (ground truth / predicted pairs from the example episode above):
# Ground Truth / Predicted Syriac_(Serto)/character03 / Syriac_(Serto)/character03 Syriac_(Serto)/character03 / Syriac_(Serto)/character03 Syriac_(Serto)/character03 / Syriac_(Serto)/character03 Syriac_(Serto)/character03 / Syriac_(Serto)/character03 Syriac_(Serto)/character03 / Syriac_(Serto)/character03 Syriac_(Serto)/character03 / Syriac_(Serto)/character03 Syriac_(Serto)/character03 / Syriac_(Serto)/character03 Syriac_(Serto)/character03 / Syriac_(Serto)/character03 Syriac_(Serto)/character03 / Syriac_(Serto)/character03 Syriac_(Serto)/character03 / Syriac_(Serto)/character03 Manipuri/character03 / Tengwar/character02 Manipuri/character03 / Tengwar/character02 Manipuri/character03 / Tengwar/character02 Manipuri/character03 / Syriac_(Serto)/character03 Manipuri/character03 / Tengwar/character02 Manipuri/character03 / Tengwar/character02 Manipuri/character03 / Tengwar/character02 Manipuri/character03 / Manipuri/character03 Manipuri/character03 / Tengwar/character02 Manipuri/character03 / Manipuri/character03 Tengwar/character02 / Tengwar/character02 Tengwar/character02 / Tengwar/character02 Tengwar/character02 / Tengwar/character02 Tengwar/character02 / Tengwar/character02 Tengwar/character02 / Tengwar/character02 Tengwar/character02 / Tengwar/character02 Tengwar/character02 / Tengwar/character02 Tengwar/character02 / Tengwar/character02 Tengwar/character02 / Tengwar/character02 Tengwar/character02 / Tengwar/character02
def evaluate_on_one_task(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
):
    """
    Returns the number of correct predictions of query labels, and the total number of predictions.
    """
    # Score the episode, detach from the autograd graph, and take the argmax class.
    scores = model(
        support_images.cuda(), support_labels.cuda(), query_images.cuda()
    ).detach()
    predictions = torch.max(scores.data, 1)[1]
    # Count matches against the ground-truth query labels.
    n_correct = (predictions == query_labels.cuda()).sum().item()
    return n_correct, len(query_labels)
def evaluate(data_loader: DataLoader):
    """
    Run episodic evaluation over every task in data_loader and print accuracy.

    Each batch is one episode: (support_images, support_labels, query_images,
    query_labels, class_ids) as produced by TaskSampler.episodic_collate_fn.
    """
    # We'll count everything and compute the ratio at the end.
    total_predictions = 0
    correct_predictions = 0
    # eval mode affects the behaviour of some layers (such as batch normalization or dropout);
    # no_grad() tells torch not to keep in memory the whole computational graph (it's more lightweight this way).
    model.eval()
    with torch.no_grad():
        # The original enumerate() index and the class_ids were unused — dropped.
        for support_images, support_labels, query_images, query_labels, _ in tqdm(
            data_loader, total=len(data_loader)
        ):
            correct, total = evaluate_on_one_task(
                support_images, support_labels, query_images, query_labels
            )
            total_predictions += total
            correct_predictions += correct
    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"
    )
# Baseline accuracy before any episodic fine-tuning (ImageNet features only).
evaluate(test_loader)
# Notebook output:
# 100%|██████████| 1000/1000 [00:34<00:00, 28.77it/s]
# Model tested on 1000 tasks. Accuracy: 86.85%
# Episodic training setup.
N_TRAINING_EPISODES = 40000
N_VALIDATION_TASKS = 1000
# As with test_set, TaskSampler needs a get_labels() method on the dataset.
train_set.get_labels = lambda: [instance[1] for instance in train_set._flat_character_images]
train_sampler = TaskSampler(
    train_set, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TRAINING_EPISODES
)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,  # one few-shot episode per batch
    num_workers=12,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
# Cross-entropy over the negative-distance scores; Adam on all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def fit(
    support_images: torch.Tensor,
    support_labels: torch.Tensor,
    query_images: torch.Tensor,
    query_labels: torch.Tensor,
) -> float:
    """Run one episodic training step and return the scalar loss value."""
    optimizer.zero_grad()
    # Forward pass: score the queries against prototypes built from the support set.
    scores = model(
        support_images.cuda(), support_labels.cuda(), query_images.cuda()
    )
    episode_loss = criterion(scores, query_labels.cuda())
    # Backward pass and parameter update.
    episode_loss.backward()
    optimizer.step()
    return episode_loss.item()
# Train the model yourself with this cell.
log_update_frequency = 10  # refresh the progress-bar loss every 10 episodes
all_loss = []
model.train()
with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    for episode_index, (
        support_images,
        support_labels,
        query_images,
        query_labels,
        _,  # class_ids, unused during training
    ) in tqdm_train:
        loss_value = fit(support_images, support_labels, query_images, query_labels)
        all_loss.append(loss_value)
        # Show a smoothed (sliding-window) loss rather than the noisy per-episode value.
        if episode_index % log_update_frequency == 0:
            tqdm_train.set_postfix(loss=sliding_average(all_loss, log_update_frequency))
# Notebook output:
# 100%|██████████| 40000/40000 [38:20<00:00, 17.39it/s, loss=0.119]
# Load weights from a previously trained run and re-evaluate.
# NOTE(review): hard-coded Google Drive path; the filename suggests "3w_2s_40kt"
# (3-way / 2-shot?) which differs from N_SHOT=5 above — confirm this checkpoint
# matches the current episode configuration.
model.load_state_dict(torch.load("/content/drive/MyDrive/Projets/FewShot/resnet18_with_pretraining_3w_2s_40kt.pth", map_location="cuda"))
evaluate(test_loader)
# Notebook output:
# 100%|██████████| 1000/1000 [00:35<00:00, 28.13it/s]
# Model tested on 1000 tasks. Accuracy: 98.83%